#include "THP.h"

// In cases like DataLoader, if a worker process dies due to a bus error/segfault
// or just hangs, the main process, if implemented with
// multiprocessing.queue.SimpleQueue, will hang waiting for data. This is
// difficult to avoid on the PyTorch side as it can be caused by limited shm, or
// other libraries users call in the workers. The following methods are an effort
// to do our best to provide some error message to users when such unfortunate
// events happen.

// TODO: The following doesn't work on Windows. Specifically, the sigaction and
// waitid calls, and the SIGCHLD handler. Currently, dummy implementations are
// provided for Windows.

#ifndef _WIN32

#include <sys/wait.h>
#include <map>
#include <set>
#include <atomic>
#include <signal.h>


// Critical signal handlers should be registered on worker processes before
// doing work.
// SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG) defines a static
// sigaction-style handler named HANDLER_NAME that:
//   1. writes ERROR_MSG to stderr, without its terminating NUL byte
//      (ERROR_MSG must be a string literal, since its length is taken with
//      sizeof);
//   2. resets the disposition of SIGNAL to SIG_DFL;
//   3. raises SIGNAL again, so that the default handler runs and the kill
//      information can be retrieved from the main process.
// If the default disposition cannot be restored, the process exits with
// EXIT_FAILURE instead.
// Python handle is _set_worker_signal_handlers().
// NOTE: comments inside the macro body must use the /* */ form; a // comment
// would swallow the next line through the line continuation.
#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG)                       \
static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx)                 \
{                                                                             \
  (void)sig;                                                                  \
  (void)info;                                                                 \
  (void)ctx;                                                                  \
  /* sizeof counts the trailing NUL of the literal; do not emit it. */        \
  auto _w = write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) - 1);           \
  (void)_w;                                                                   \
  struct sigaction sa;                                                        \
  sa.sa_handler = SIG_DFL;                                                    \
  sa.sa_flags = 0;                                                            \
  if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGNAL, &sa, NULL) != 0) {   \
    _exit(EXIT_FAILURE);                                                      \
  } else {                                                                    \
    raise(SIGNAL);                                                            \
  }                                                                           \
}

// Installs `handler` as the SA_SIGINFO-style handler for `signal`.
// sigaction is used because signal(2) is really not portable:
// http://man7.org/linux/man-pages/man2/signal.2.html
// `old_sa_ptr` is forwarded to sigaction() as the place to store the previous
// action; it may be NULL.
// Throws std::runtime_error if the signal mask cannot be cleared or the
// handler cannot be installed.
static void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *), struct sigaction *old_sa_ptr)
{
  struct sigaction action;
  action.sa_flags = SA_RESTART|SA_SIGINFO|SA_NOCLDSTOP|SA_NODEFER;
  action.sa_sigaction = handler;

  const bool installed = sigemptyset(&action.sa_mask) == 0
      && sigaction(signal, &action, old_sa_ptr) == 0;
  if (installed) {
    return;
  }

  std::ostringstream message;
  message << "An error occurred while setting handler for " << strsignal(signal) << ".";
  throw std::runtime_error(message.str());
}

// Define the worker-side handlers for SIGBUS and SIGSEGV via SIGNAL_HANDLER.
// Each one writes its message to stderr, restores the default disposition of
// its signal and raises the signal again.
SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered in worker. "
  "This might be caused by insufficient shared memory (shm).\n");
SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n");

// Python handle: _set_worker_signal_handlers().
// Registers handler_SIGBUS and handler_SIGSEGV (in that order) for the calling
// process and returns True. setSignalHandler throws std::runtime_error when a
// handler cannot be installed; that exception is left to the
// HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS wrappers.
PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *arg) {
  HANDLE_TH_ERRORS
  struct HandlerEntry {
    int signum;
    void (*fn)(int, siginfo_t *, void *);
  };
  const HandlerEntry entries[] = {
    {SIGBUS, &handler_SIGBUS},
    {SIGSEGV, &handler_SIGSEGV},
  };
  for (const HandlerEntry &entry : entries) {
    // The previous disposition is not needed, so no old action is requested.
    setSignalHandler(entry.signum, entry.fn, nullptr);
  }
  Py_RETURN_TRUE;
  END_HANDLE_TH_ERRORS
}

// Worker pids to watch, grouped under the int64 key supplied to
// _update_worker_pids (one entry per DataLoader). Starts out empty.
static std::map<int64_t, std::set<pid_t>> worker_pids;

// Python handle: _error_if_any_worker_fails().
// Polls every pid recorded in worker_pids without blocking and throws
// std::runtime_error for the first worker found to have exited with a
// non-zero code or to have been killed by a signal. Returns None when no such
// worker is found.
PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
  HANDLE_TH_ERRORS
  // Only check the pids we care about
  for (auto &entry : worker_pids) {
    std::set<pid_t> &pids = entry.second;
    for (const pid_t worker_pid : pids) {
      // Use waitid rather than waitpid so that we can set NOWAIT, and that
      // Python and other handlers can get whatever info they want about the
      // child.
      siginfo_t info;
      info.si_pid = 0;  // stays 0 when there is no waitable child
      const int rc = waitid(P_PID, worker_pid, &info, WEXITED|WNOHANG|WNOWAIT);
      // ignore errors and case with no waitable child
      if (rc < 0 || info.si_pid == 0) {
        continue;
      }

      const bool exited_with_error =
          info.si_code == CLD_EXITED && info.si_status != 0;
      const bool killed_by_signal =
          info.si_code == CLD_KILLED || info.si_code == CLD_DUMPED;
      if (!exited_with_error && !killed_by_signal) {
        continue;
      }

      std::ostringstream message;
      message << "DataLoader worker (pid " << worker_pid << ") ";
      if (exited_with_error) {
        message << "exited unexpectedly with exit code " << info.si_status << ".";
      } else {
        message << "is killed by signal: " << strsignal(info.si_status) << ".";
      }
      // This is necessary. Otherwise, the runtime error will kill the other
      // workers, and trigger this again.
      pids.clear();
      throw std::runtime_error(message.str());
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// Python handle: _update_worker_pids(key, child_pids).
// We don't want to exit on any SIGCHLD from any child. child_pids is a tuple
// of pids we are interested in; they are recorded in worker_pids under `key`.
// Fails the THPUtils_assert checks when the argument count is not 2, when
// `key` is already registered, or when child_pids is not a tuple.
// Returns None on success.
PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *args) {
  HANDLE_TH_ERRORS
  Py_ssize_t num_args = args ? (Py_ssize_t) PyTuple_Size(args) : 0;
  THPUtils_assert(num_args == 2, "_update_worker_pids expects exactly 2 arguments.");
  int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
  THPUtils_assert(worker_pids.find(key) == worker_pids.end(), "_update_worker_pids "
        "should be called only once for each DataLoader.");
  PyObject *child_pids = PyTuple_GET_ITEM(args, 1);
  THPUtils_assert(PyTuple_Check(child_pids), "_update_worker_pids "
        "expects a tuple for child_pids, but got %s.", THPUtils_typename(child_pids));

  // Collect into a local set first so that worker_pids is only modified once
  // every element has been unpacked.
  std::set<pid_t> pids_set;
  const Py_ssize_t size = PyTuple_GET_SIZE(child_pids);
  // Index with Py_ssize_t to match the type of the tuple size.
  for (Py_ssize_t idx = 0; idx < size; idx++) {
    PyObject* obj = PyTuple_GET_ITEM(child_pids, idx);
    pids_set.insert((pid_t) THPUtils_unpackLong(obj));
  }

  // `key` was checked to be absent above, so operator[] creates an empty set;
  // swapping hands over the collected pids without copying them.
  worker_pids[key].swap(pids_set);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// Python handle: _remove_worker_pids(loader_id).
// Drops the pid set recorded in worker_pids under `loader_id`. Fails the
// THPUtils_assert check when no entry exists for that id. Returns None on
// success.
PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *loader_id) {
  HANDLE_TH_ERRORS

  int64_t key = THPUtils_unpackLong(loader_id);
  // Look the entry up once and reuse the iterator for the erase below.
  auto entry = worker_pids.find(key);
  // int64_t is not `long` on every platform, so pass an explicit long long
  // with a matching %lld (assumes THPUtils_assert takes a printf-style format,
  // as its existing %s usage in this file suggests).
  THPUtils_assert(entry != worker_pids.end(), "Cannot find worker "
        "information for DataLoader with id %lld.", static_cast<long long>(key));

  worker_pids.erase(entry);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

#undef SIGNAL_HANDLER

#else
// dummy implementations for windows

// Windows stub for _set_worker_signal_handlers(): installs nothing and
// returns True.
PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *_ignored) {
  Py_RETURN_TRUE;
}

// Windows stub for _update_worker_pids(): records nothing. Returns None so
// that the Python-visible result matches the non-Windows implementation in
// this file.
PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *_ignored) {
  Py_RETURN_NONE;
}

// Windows stub for _remove_worker_pids(): nothing is tracked, so there is
// nothing to remove. Returns None.
PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *_ignored) {
  Py_RETURN_NONE;
}

// Windows stub for _error_if_any_worker_fails(): performs no check and
// returns None.
PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module, PyObject *_ignored) {
  Py_RETURN_NONE;
}

#endif

// Method table for the DataLoader helpers: maps each Python-visible name to
// its implementation and calling convention. None of the entries carries a
// docstring. The all-null entry at the end terminates the table.
PyMethodDef DataLoaderMethods[] = {
  {"_set_worker_signal_handlers",
   reinterpret_cast<PyCFunction>(THPModule_setWorkerSignalHandlers),
   METH_NOARGS, nullptr},
  {"_update_worker_pids",
   reinterpret_cast<PyCFunction>(THPModule_updateWorkerPIDs),
   METH_VARARGS, nullptr},
  {"_remove_worker_pids",
   reinterpret_cast<PyCFunction>(THPModule_removeWorkerPIDs),
   METH_O, nullptr},
  {"_error_if_any_worker_fails",
   reinterpret_cast<PyCFunction>(THPModule_errorIfAnyWorkerFails),
   METH_NOARGS, nullptr},
  {nullptr, nullptr, 0, nullptr}
};
